# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Auto-tuning a Convolutional Network for NVIDIA GPU
==================================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_, `Eddie Yan <https://github.com/eqy/>`_

Auto-tuning for specific devices and workloads is critical for getting the
best performance. This is a tutorial on how to tune a whole convolutional
network for NVIDIA GPU.

The operator implementation for NVIDIA GPU in TVM is written in template form.
The template has many tunable knobs (tile factor, unrolling, etc).
We will tune all convolution and depthwise convolution operators
in the neural network. After tuning, we produce a log file which stores
the best knob values for all required operators. When the TVM compiler compiles
these operators, it will query this log file to get the best knob values.

We also released pre-tuned parameters for some NVIDIA GPUs. You can go to
`NVIDIA GPU Benchmark <https://github.com/apache/tvm/wiki/Benchmark#nvidia-gpu>`_
to see the results.

Note that this tutorial will not run on Windows or recent versions of macOS. To
get it to run, you will need to wrap the body of this tutorial in a :code:`if
__name__ == "__main__":` block.
"""

import os

import numpy as np

import tvm
from tvm import relay, autotvm
from tvm.autotvm.tuner.raft_ga_tuner import raftGATuner
from tvm.autotvm.tuner.raft_xgb_tuner import raftXGBTuner
from tvm.autotvm.tuner.raft_suiji_tuner import raftSuijiTuner
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner,fastestTuner_v3,fastestTuner_v2
import tvm.contrib.graph_executor as runtime
import logging
#logging.getLogger('autotvm').setLevel(logging.DEBUG)
def get_network(name, batch_size):
    """Build the Relay module and parameters for a named network.

    Parameters
    ----------
    name : str
        Network selector. A name containing "resnet" or "vgg" must carry the
        layer count as its second dash-separated field (for example
        "resnet-18"). The exact names "mobilenet", "squeezenet_v1.1",
        "inception_v3" and "mxnet" are accepted as well.
    batch_size : int
        Batch dimension used for the workload and for the returned shapes.

    Returns
    -------
    tuple
        ``(mod, params, input_shape, output_shape)``.

    Raises
    ------
    ValueError
        If ``name`` matches none of the supported networks.

    Notes
    -----
    The module-level ``dtype`` is read when the function is called.
    """
    data_shape = (batch_size, 3, 224, 224)
    result_shape = (batch_size, 1000)

    if "resnet" in name or "vgg" in name:
        # Both families take the layer count from the name; "resnet" wins
        # when a name happens to contain both substrings.
        depth = int(name.split("-")[1])
        family = relay.testing.resnet if "resnet" in name else relay.testing.vgg
        mod, params = family.get_workload(
            num_layers=depth, batch_size=batch_size, dtype=dtype
        )
    elif name == "mobilenet":
        mod, params = relay.testing.mobilenet.get_workload(
            batch_size=batch_size, dtype=dtype
        )
    elif name == "squeezenet_v1.1":
        mod, params = relay.testing.squeezenet.get_workload(
            batch_size=batch_size, version="1.1", dtype=dtype
        )
    elif name == "inception_v3":
        # This workload is built for 299x299 inputs instead of 224x224.
        data_shape = (batch_size, 3, 299, 299)
        mod, params = relay.testing.inception_v3.get_workload(
            batch_size=batch_size, dtype=dtype
        )
    elif name == "mxnet":
        # Example of importing a model from the MXNet model zoo.
        from mxnet.gluon.model_zoo.vision import get_model

        gluon_block = get_model("resnet18_v1", pretrained=True)
        mod, params = relay.frontend.from_mxnet(
            gluon_block, shape={"data": data_shape}, dtype=dtype
        )
        # Rebuild the main function with a softmax applied to its body.
        main_fn = mod["main"]
        with_softmax = relay.Function(
            main_fn.params,
            relay.nn.softmax(main_fn.body),
            None,
            main_fn.type_params,
            main_fn.attrs,
        )
        mod = tvm.IRModule.from_expr(with_softmax)
    else:
        raise ValueError("Unsupported network: " + name)

    return mod, params, data_shape, result_shape



# Compilation target used for task extraction, building and evaluation.
target = tvm.target.iluvatar()

# The network to tune and the tuner to use are both read from environment
# variables; a missing variable raises KeyError at this point.
network = os.environ["network"]
tuner = os.environ["tuner"]

# Data type handed to the workload constructors and used for the random input.
dtype = "float32"
# The best records end up in "<network>.log".
log_file = network + ".log"

# Keyword arguments that tune_and_evaluate() forwards to tune_tasks().
tuning_option = dict(
    log_filename=log_file,
    tuner=tuner,
    n_trial=1000,
    early_stopping=200,
    measure_option=autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150),
        # Alternative runner that measures through an RPC tracker:
        # runner=autotvm.RPCRunner(
        #     'BI-V100',  # change the device key to your key
        #     '127.0.0.1', 9191,
        #     number=20, repeat=3, timeout=4, min_repeat_ms=150)
    ),
)

# Parameters for raft (alternative tuning_option, currently commented out)

# tuning_option = {
#     "log_filename": log_file,
#     "tuner": tuner,
#     "n_trial": 10,
#     "early_stopping": 150,
#     "measure_option": autotvm.measure_option(
#         builder=autotvm.LocalBuilder(timeout=10),
#         runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150),
#     ),
# }


def tune_tasks(
    tasks,
    measure_option,
    tuner="fastest",
    n_trial=1,
    early_stopping=None,
    log_filename="tuning.log",
    use_transfer_learning=True,
    try_winograd=False,
):
    """Tune the "conv2d_nchw.cuda" tasks one by one and keep the best records.

    Only tasks whose ``name`` is "conv2d_nchw.cuda" are kept. The kept tasks
    are visited in reverse order and the first one visited (the last kept
    task) is skipped. Every measurement is logged to ``log_filename + ".tmp"``
    and the best records are finally picked into ``log_filename``.

    Parameters
    ----------
    tasks : sequence
        Task objects exposing ``name``, ``args``, ``target``, ``target_host``
        and ``config_space``.
    measure_option
        Measurement settings forwarded unchanged to each tuner's ``tune``.
    tuner : str
        One of "xgb", "xgb-rank", "ga", "random", "gridsearch", "fastest",
        "raft_ga", "raft_xgb" or "raft_suiji".
    n_trial : int
        Upper bound on trials per task; capped at the config-space size.
    early_stopping : int or None
        Forwarded unchanged to each tuner's ``tune``.
    log_filename : str
        Destination of the best records; the temporary log is this name
        with ".tmp" appended.
    use_transfer_learning : bool
        Currently has no effect because it is overwritten with False below.
    try_winograd : bool
        Accepted but not used.

    Raises
    ------
    ValueError
        If a task is about to be tuned and ``tuner`` is not recognised.
    """
    # Restrict tuning to the direct conv2d template.
    conv_tasks = [task for task in tasks if task.name == "conv2d_nchw.cuda"]

    # Show the tasks that were kept.
    print("调优任务如下：")
    for task in conv_tasks:
        print(task.name)
        print(task.args)
        print(task.target)
        print(task.target_host)
        print("==========")

    # The temporary log is neither removed before tuning nor afterwards, so a
    # file left over from an earlier run stays in place.
    tmp_log_file = log_filename + ".tmp"

    # Tuner name -> constructor. Lambdas keep the construction lazy.
    tuner_factories = {
        "xgb": lambda task: XGBTuner(task, loss_type="rank"),
        "xgb-rank": lambda task: XGBTuner(task, loss_type="rank"),
        "ga": lambda task: GATuner(task, pop_size=100),
        "random": lambda task: RandomTuner(task),
        "gridsearch": lambda task: GridSearchTuner(task),
        "fastest": lambda task: fastestTuner_v2(task),
        "raft_ga": lambda task: raftGATuner(task),
        "raft_xgb": lambda task: raftXGBTuner(task),
        "raft_suiji": lambda task: raftSuijiTuner(task),
    }

    # NOTE(review): transfer learning is force-disabled regardless of the
    # argument passed by the caller -- confirm this is intended.
    use_transfer_learning = False

    total = len(conv_tasks)
    for position, task in enumerate(reversed(conv_tasks), start=1):
        # NOTE(review): the first task visited is always skipped -- confirm
        # this is intended.
        if position == 1:
            continue
        prefix = "[Task %2d/%2d] " % (position, total)

        # Create the tuner for this task.
        make_tuner = tuner_factories.get(tuner)
        if make_tuner is None:
            raise ValueError("Invalid tuner: " + tuner)
        tuner_obj = make_tuner(task)

        if use_transfer_learning and os.path.isfile(tmp_log_file):
            tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # Never request more trials than the config space holds.
        space_size = len(task.config_space)
        trials = min(n_trial, space_size)
        print("调度空间：%d" % space_size)
        tuner_obj.tune(
            n_trial=trials,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(trials, prefix=prefix),
                autotvm.callback.log_to_file(tmp_log_file),
            ],
        )

    # Pick the best records from the temporary log into the final log file.
    autotvm.record.pick_best(tmp_log_file, log_filename)

def tune_and_evaluate(tuning_opt):
    """Extract the conv2d tasks of the network, tune them and time inference.

    The network, target, log file and dtype come from the module-level
    ``network``, ``target``, ``log_file`` and ``dtype``.

    Parameters
    ----------
    tuning_opt : dict
        Keyword arguments forwarded to ``tune_tasks``.
    """
    # Extract the tuning workloads from the Relay program.
    print("Extract tasks...")
    mod, params, input_shape, _ = get_network(network, batch_size=1)

    conv2d_op = relay.op.get("nn.conv2d")
    tasks = autotvm.task.extract_from_program(
        mod["main"], target=target, params=params, ops=(conv2d_op,)
    )

    # Run the tuning tasks.
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # Build and evaluate with the best records from the log applied.
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build_module.build(mod, target=target, params=params)

        # Create the graph executor and feed it a random input tensor.
        dev = tvm.device(str(target), 0)
        graph_module = runtime.GraphModule(lib["default"](dev))
        random_input = np.random.uniform(size=input_shape).astype(dtype)
        graph_module.set_input("data", tvm.nd.array(random_input))

        # Time the "run" function and report mean and standard deviation.
        print("Evaluate inference time cost...")
        timer = graph_module.module.time_evaluator("run", dev, number=1, repeat=600)
        times_ms = np.array(timer().results) * 1000  # seconds -> milliseconds
        mean_ms = np.mean(times_ms)
        std_ms = np.std(times_ms)
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (mean_ms, std_ms))


# Run the full pipeline with the module-level tuning_option. The call is
# unconditional, so it executes whenever this file is run or imported;
# tuning can take a long time.

tune_and_evaluate(tuning_option)

